// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s -sparsification | FileCheck %s

// Sparse encoding used for operand A of @mul: both dimensions of the 2-D
// tensor are stored as "compressed" levels (the expected output reads a
// pointers array and an indices array for dimension 0 and for dimension 1).
#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

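// A contrived example that demonstrates the many different ways
// in which scalar values can participate in a sparse kernel
// expressed as a linalg.generic op: as a scalar tensor operand,
// as a true scalar operand, and as values captured from outside
// the op's region (a function argument, a computed value, and a
// constant).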

// Indexing maps and iterators used by the linalg.generic op in @mul.
// p and q use zero-result maps, so the same scalar is read at every (i,j).
// r, s, and 2.2 from the doc string are not operands of the op; the op's
// region captures them from the enclosing function instead.
#trait = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A (sparse tensor)
    affine_map<(i,j) -> ()>,     // p (scalar tensor)
    affine_map<(i,j) -> ()>,     // q (true scalar)
    affine_map<(i,j) -> (i,j)>   // X (dense tensor out)
  ],
  // Both loops are parallel: each X(i,j) is updated independently.
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) += A(i,j) * p * q * r * s * 2.2"
}

// CHECK-LABEL:   func @mul(
// CHECK-SAME:              %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME:              %[[VAL_1:.*1]]: tensor<f32>,
// CHECK-SAME:              %[[VAL_2:.*2]]: f32,
// CHECK-SAME:              %[[VAL_3:.*3]]: f32,
// CHECK-SAME:              %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
// CHECK:           %[[VAL_5:.*]] = constant 2.200000e+00 : f32
// CHECK:           %[[VAL_6:.*]] = constant 0 : index
// CHECK:           %[[VAL_7:.*]] = constant 1 : index
// CHECK:           %[[VAL_8:.*]] = addf %[[VAL_2]], %[[VAL_3]] : f32
// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
// CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
// CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_4]] : memref<32x16xf32>
// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_14]][] : memref<f32>
// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_7]] {
// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]] = addi %[[VAL_19]], %[[VAL_7]] : index
// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_7]] {
// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK:               %[[VAL_27:.*]] = mulf %[[VAL_26]], %[[VAL_16]] : f32
// CHECK:               %[[VAL_28:.*]] = mulf %[[VAL_27]], %[[VAL_2]] : f32
// CHECK:               %[[VAL_29:.*]] = mulf %[[VAL_28]], %[[VAL_3]] : f32
// CHECK:               %[[VAL_30:.*]] = mulf %[[VAL_29]], %[[VAL_8]] : f32
// CHECK:               %[[VAL_31:.*]] = mulf %[[VAL_30]], %[[VAL_5]] : f32
// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
// CHECK:               %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK:               memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_34:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16xf32>
// CHECK:           return %[[VAL_34]] : tensor<32x16xf32>
// CHECK:         }
// Sparse kernel under test: X(i,j) += A(i,j) * p * q * r * s * 2.2, with
// s = q + r computed before the linalg.generic op. The scalars reach the
// op's region in five different ways, annotated on each mulf below.
// Returns the updated X tensor.
// NOTE(review): {linalg.inplaceable = true} presumably allows the sparsifier
// to write into %argx's buffer directly; the expected output above stores
// into the memref.buffer_cast of that argument rather than a copy — confirm.
func @mul(%arga: tensor<32x16xf32, #SparseMatrix>,
          %argp: tensor<f32>,
          %argq: f32,
          %argr: f32,
          %argx: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
  // Defined outside the generic op; the expected output keeps this addf
  // outside both generated loops.
  %s = addf %argq, %argr : f32
  %c = constant 2.2 : f32
  // Operands map to block arguments in order: A element, p, q, X element.
  // The %0..%5 defined inside the region are separate from this op's
  // result %0, which is what the function returns.
  %0 = linalg.generic #trait
     ins(%arga, %argp, %argq: tensor<32x16xf32, #SparseMatrix>, tensor<f32>, f32)
    outs(%argx: tensor<32x16xf32>) {
      ^bb(%a: f32, %p: f32, %q: f32, %x: f32):
        %0 = mulf %a, %p : f32     // scalar tensor argument
        %1 = mulf %0, %q : f32     // scalar argument
        %2 = mulf %1, %argr : f32  // scalar argument from outside block
        %3 = mulf %2, %s : f32     // scalar value from outside block
        %4 = mulf %3, %c : f32     // direct constant from outside block
        %5 = addf %4, %x : f32     // accumulate into X (the += in the trait doc)
        linalg.yield %5  : f32
  } -> tensor<32x16xf32>

  return %0 : tensor<32x16xf32>
}
